# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2021
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the following white-box attacks related to PE malware crafting:
    1) Append based attacks (example paper link: https://arxiv.org/abs/1810.08280)
    2) Section insertion attacks (example paper link: https://arxiv.org/abs/2008.07125)
    3) Slack manipulation attacks (example paper link: https://arxiv.org/abs/1810.08280)
    4) DOS Header Attacks (example paper link: https://arxiv.org/abs/1901.03583)
"""

import json
import random
import logging
from typing import Optional, Union, Tuple, List, Dict, TYPE_CHECKING
from tqdm.auto import trange

import numpy as np

from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin
from art.estimators.classification.classifier import ClassifierMixin
from art.attacks.attack import EvasionAttack

if TYPE_CHECKING:
    # pylint: disable=C0412
    import tensorflow as tf
    from art.utils import CLASSIFIER_NEURALNETWORK_TYPE

logger = logging.getLogger(__name__)


class MalwareGDTensorFlow(EvasionAttack):
    """
    Implementation of the following white-box attacks related to PE malware crafting:
        1) Append based attacks (example paper link: https://arxiv.org/abs/1810.08280)
        2) Section insertion attacks (example paper link: https://arxiv.org/abs/2008.07125)
        3) Slack manipulation attacks (example paper link: https://arxiv.org/abs/1810.08280)
        4) DOS Header Attacks (example paper link: https://arxiv.org/abs/1901.03583)
    """

    # Names of this attack's configurable attributes (every constructor argument except the classifier),
    # appended to the parameter names inherited from EvasionAttack.
    attack_params = EvasionAttack.attack_params + [
        "embedding_weights",
        "param_dic",
        "num_of_iterations",
        "l_0",
        "l_r",
        "use_sign",
        "verbose",
    ]

    # Estimator base class and mixins associated with this attack.
    # NOTE(review): presumably the base class validates the supplied classifier against this tuple - confirm
    # against EvasionAttack, which is not visible here.
    _estimator_requirements = (BaseEstimator, NeuralNetworkMixin, ClassifierMixin)

    def __init__(
        self,
        classifier: "CLASSIFIER_NEURALNETWORK_TYPE",
        embedding_weights: np.ndarray,
        param_dic: Dict[str, int],
        num_of_iterations: int = 10,
        l_0: Union[float, int] = 0.1,
        l_r: float = 1.0,
        use_sign: bool = False,
        verbose: bool = False,
    ) -> None:
        """
        Create the attack, store its configuration and validate it.

        :param classifier: A trained classifier that takes in the PE embeddings to make a prediction.
        :param embedding_weights: Weights of the embedding layer. They are validated as a numpy array and then
                                  stored as float32.
        :param param_dic: A dictionary specifying some MalConv parameters.
                          'maxlen': the input size to the MalConv model
                          'input_dim': the number of discrete values, normally 257.
                          'embedding_size': size of the embedding layer. Default 8.
        :param num_of_iterations: The number of optimisation iterations to apply.
        :param l_0: l_0 bound for the attack. If less than 1 it is interpreted as a fraction of the file size,
                    otherwise it is interpreted as the total number of permissible features to change.
        :param l_r: Learning rate for the optimisation.
        :param use_sign: Whether to step along the sign of the gradient rather than the gradient itself.
        :param verbose: Show progress bars.
        :raises ValueError: If one of the parameters fails validation in `_check_params`.
        """
        super().__init__(estimator=classifier)

        # description of the embedding layer of the target model
        self.embedding_weights = embedding_weights
        self.param_dic = param_dic

        # optimisation settings
        self.num_of_iterations = num_of_iterations
        self.l_0 = l_0
        self.l_r = l_r
        self.use_sign = use_sign
        self.verbose = verbose

        # placeholder; the per-sample budget is assigned when the attack is run
        self.total_perturbation: np.ndarray = np.zeros(shape=(1, 1))

        # validate first (this checks that the weights are a numpy array), only then cast them
        self._check_params()
        self.embedding_weights = self.embedding_weights.astype("float32")

    def _check_params(self) -> None:
        if not isinstance(self.param_dic, dict):
            raise ValueError(
                "A param_dic should be provided with the following keys/value pairs: "
                "'maxlen': the input size to the MalConv model"
                "'input_dim': the number of discrete values. Normally 257."
                "'embedding_size': size of the embedding layer. Normally 8."
            )

        if not isinstance(self.embedding_weights, np.ndarray):
            raise ValueError("The weights for the embedding layer should be given as a numpy array.")

        if not isinstance(self.l_0, (int, float)) or self.l_0 < 0:
            raise ValueError(
                "The l0 bound should be greater or equal to 0. Further, it should be provided as an "
                "integer specifying the total number of features to perturb, or a a float representing "
                "the fraction of the total number of features of the original file we can perturb."
            )

        if not isinstance(self.l_r, (int, float)) or self.l_r < 0:
            raise ValueError("The learning rate should be a float or integer greater than zero.")

        if not isinstance(self.use_sign, bool):
            raise ValueError("Whether to use the sign of the gradient should be a True/False bool.")

        if not isinstance(self.num_of_iterations, int) or self.num_of_iterations < 0:
            raise ValueError("The number of iterations must be an integer greater than zero.")

        if not isinstance(self.verbose, bool):
            raise ValueError("The verbosity level should be a True/False bool.")

    @staticmethod
    def initialise_sample(
        x: np.ndarray,
        y: np.ndarray,
        sample_sizes: np.ndarray,
        perturbation_size: np.ndarray,
        perturb_sizes: Optional[List[List[int]]],
        perturb_starts: Optional[List[List[int]]],
    ) -> np.ndarray:
        """
        Randomly append bytes at the end of the malware to initialise it, or if perturbation regions are provided,
        perturb those.

        :param x: Array with input data.
        :param y: Labels, after having been adjusted to account for malware which cannot support the full l0 budget.
        :param sample_sizes: The size of the original file, before it was padded to the input size required by MalConv
        :param perturbation_size: Size of the perturbations in L0 terms to put at end of file
        :param perturb_sizes: List of length batch size, each element is in itself a list containing the size
                              of the allowable perturbation region
        :param perturb_starts: List of length batch size, each element is in itself a list containing the start
                               of perturbation region.
        :return x: Array with features to be perturbed set to a random value.
        """
        for j in range(len(x)):
            if y[j] == 1:
                if perturb_sizes is not None and perturb_starts is not None:
                    for size, start in zip(perturb_sizes[j], perturb_starts[j]):
                        x[j, start : start + size] = np.random.randint(low=0, high=256, size=(1, size))
                x[j, sample_sizes[j] : sample_sizes[j] + perturbation_size[j]] = np.random.randint(
                    low=0, high=256, size=(1, perturbation_size[j])
                )

        return x

    def check_valid_size(
        self,
        y: np.ndarray,
        sample_sizes: np.ndarray,
        append_perturbation_size: np.ndarray,
    ) -> np.ndarray:
        """
        Checks that we can append the l0 perturbation to the malware sample and not exceed the
        maximum file size. A new label vector with just the valid files indicated is created.

        :param y: Labels.
        :param sample_sizes: The size of the original file, before it was padded to the input size required by MalConv.
        :param append_perturbation_size: Size of the perturbations in L0 terms to put at end of file.
        :return adv_label_vector: Labels which indicate which malware samples have enough free features to
                                  accommodate all the adversarial perturbation.
        """

        adv_label_vector = np.zeros_like(y)
        for i, label in enumerate(y):
            if label == 1:
                if sample_sizes[i] + append_perturbation_size[i] <= self.param_dic["maxlen"]:
                    adv_label_vector[i] = 1
                    logger.info("size to append on sample %d is %d", i, append_perturbation_size[i])

        return adv_label_vector

    def generate_mask(
        self,
        x: np.ndarray,
        y: np.ndarray,
        sample_sizes: np.ndarray,
        perturbation_size: np.ndarray,
        perturb_sizes: Optional[List[List[int]]],
        perturb_starts: Optional[List[List[int]]],
    ) -> "tf.Tensor":
        """
        Build the gradient mask that selects which features of which samples may be updated.

        Rows whose label is not 1 stay all-zero. For the other rows the supplied perturbation regions (if any)
        and the appended region after the original file content are set to 1. The per-feature mask is then
        repeated `param_dic["embedding_size"]` times along a new last axis and returned as a float32 tensor.

        :param x: Array with input data.
        :param y: Labels to make sure the benign files are zero masked.
        :param sample_sizes: The size of the original file, before it was padded to the input size required by MalConv
        :param perturbation_size: Size of the perturbations in L0 terms to put at end of file
        :param perturb_sizes: List of length batch size, each element is in itself a list containing the size
                              of the allowable perturbation region
        :param perturb_starts: List of length batch size, each element is in itself a list containing the start
                               of perturbation region.
        :return mask: Tensor with 1s on the features we will modify on this batch and 0s elsewhere.
        :raises ValueError: If only one of `perturb_sizes` / `perturb_starts` is supplied.
        """
        import tensorflow as tf

        feature_mask = np.zeros_like(x)
        for idx in range(len(x)):
            if not y[idx] == 1:
                continue

            if (perturb_sizes is None) != (perturb_starts is None):  # pragma: no cover
                raise ValueError(
                    "either both start and size of perturbation regions are supplied or neither is supplied"
                )

            # regions inside the file, when section information was provided
            if perturb_sizes is not None and perturb_starts is not None:
                for region_size, region_start in zip(perturb_sizes[idx], perturb_starts[idx]):
                    feature_mask[idx, region_start : region_start + region_size] = 1

            # the appended perturbation at the end of the file is marked in both modes
            append_start = sample_sizes[idx]
            feature_mask[idx, append_start : append_start + perturbation_size[idx]] = 1

            # sanity check that the total number of features marked on the mask
            # does not exceed the total perturbation budget
            assert np.sum(feature_mask[idx]) == self.total_perturbation[idx]

        # repeat the mask along a new trailing axis so that it matches the embedding dimension
        single_channel = np.expand_dims(feature_mask, axis=-1)
        all_channels = np.concatenate([single_channel] * self.param_dic["embedding_size"], axis=-1)
        mask_tensor = tf.convert_to_tensor(all_channels)
        return tf.cast(mask_tensor, dtype="float32")  # pylint: disable=E1123 disable=E1120

    def update_embeddings(self, embeddings: "tf.Tensor", gradients: "tf.Tensor", mask: "tf.Tensor") -> "tf.Tensor":
        """
        Take one masked step on the embeddings along the supplied gradients.

        The step is `l_r * gradients * mask`, where the gradients are replaced by their sign when `use_sign`
        is set, so positions where the mask is 0 are left unchanged.

        :param embeddings: Embeddings produced by the data from passing it through the first embedding layer of MalConv
        :param gradients: Gradients to update the embeddings
        :param mask: Tensor with 1s on the embeddings we modify, 0s elsewhere.
        :return embeddings: Updated embeddings wrt the adversarial objective.
        """
        import tensorflow as tf

        step_direction = tf.sign(gradients) if self.use_sign else gradients
        return embeddings + self.l_r * step_direction * mask

    def get_adv_malware(
        self,
        embeddings: "tf.Tensor",
        data: np.ndarray,
        labels: np.ndarray,
        fsize: np.ndarray,
        perturbation_size: np.ndarray,
        perturb_sizes: Optional[List[List[int]]] = None,
        perturb_starts: Optional[List[List[int]]] = None,
    ) -> np.ndarray:
        """
        Project the optimised embeddings back to discrete values and write them into the data.

        For every sample labelled 1, each perturbed position is replaced by the index of the row of
        `embedding_weights` that is closest (l2 norm) to the optimised embedding at that position.

        :param embeddings: Adversarially optimised embeddings.
        :param data: Original data in the feature space. It is modified in place and also returned.
        :param labels: Labels for the data; only samples labelled 1 are projected.
        :param fsize: Size of the original malware.
        :param perturbation_size: Size of the l0 attack to append (if any).
        :param perturb_sizes: List, with each element itself being a list of the sizes of the perturbation
                              regions in a sample.
        :param perturb_starts: List, with each element itself being a list of the start positions of the
                               perturbation regions in a sample.
        :return data: Numpy array with valid data samples.
        """
        import tensorflow as tf

        def closest_embedding_rows(region_embeddings: "tf.Tensor", tile_multiples: "tf.Tensor") -> "tf.Tensor":
            # compare every position of the region with every row of the embedding matrix
            tiled = tf.tile(tf.expand_dims(region_embeddings, axis=1), tile_multiples)
            distances = tf.norm((tiled - self.embedding_weights), axis=-1)
            return tf.math.argmin(distances, axis=-1)

        for idx, label in enumerate(labels):
            if not label == 1:
                continue

            tile_multiples = tf.constant([1, self.param_dic["input_dim"], 1], tf.int32)
            projected_features = 0

            # regions inside the file
            if perturb_sizes is not None and perturb_starts is not None:
                for region_size, region_start in zip(perturb_sizes[idx], perturb_starts[idx]):
                    region = slice(region_start, region_start + region_size)
                    nearest = closest_embedding_rows(embeddings[idx, region, :], tile_multiples)
                    data[idx, region] = nearest
                    projected_features += len(nearest)

            # appended region after the original file content
            appended = slice(fsize[idx], fsize[idx] + perturbation_size[idx])
            nearest = closest_embedding_rows(embeddings[idx, appended, :], tile_multiples)
            data[idx, appended] = nearest
            projected_features += len(nearest)

            # sanity check that the total number of perturbed features does
            # not exceed the total perturbation budget
            assert projected_features == self.total_perturbation[idx]
        return data

    @staticmethod
    def pull_out_adversarial_malware(
        x: np.ndarray,
        y: np.ndarray,
        sample_sizes: np.ndarray,
        initial_dtype: np.dtype,
        input_perturb_sizes: Optional[List[List[int]]] = None,
        input_perturb_starts: Optional[List[List[int]]] = None,
    ) -> Union[
        Tuple[np.ndarray, np.ndarray, np.ndarray],
        Tuple[np.ndarray, np.ndarray, np.ndarray, List[List[int]], List[List[int]]],
    ]:
        """
        Fetches the malware from the data

        :param x: Batch of data which will contain a mix of adversarial examples and unperturbed data.
        :param y: Labels indicating which are valid adversarial examples or not.
        :param initial_dtype: Data can be given in a few formats (uin16, float, etc.) so use initial_dtype
                              to make the returned sample match the original.
        :param sample_sizes: Size of the original data files
        :param input_perturb_sizes: List of length batch size, each element is in itself a list containing
                                    the size of the allowable perturbation region
        :param input_perturb_starts: List of length batch size, each element is in itself a list containing
                                     the start of perturbation region.

        :return adv_x: array composed of only the data that we can make valid adversarial examples from.
        :return adv_y: labels, all ones.
        """
        num_of_malware_samples = int(np.sum(y))

        # make array and allocate, much faster than appending to list and converting
        adv_x = np.zeros((num_of_malware_samples, x.shape[1]), dtype=initial_dtype)
        adv_y = np.ones((num_of_malware_samples, 1))

        adv_sample_sizes = np.zeros((num_of_malware_samples,), dtype=int)
        perturb_sizes = []
        perturb_starts = []

        j = 0
        for i, label in enumerate(y):
            if label == 1:
                adv_x[j] = x[i]
                adv_sample_sizes[j] = int(sample_sizes[i])
                j += 1
                if input_perturb_sizes is not None and input_perturb_starts is not None:
                    perturb_sizes.append(input_perturb_sizes[i])
                    perturb_starts.append(input_perturb_starts[i])

        if input_perturb_sizes is not None and input_perturb_starts is not None:
            return adv_x, adv_y, adv_sample_sizes, perturb_starts, perturb_sizes

        return adv_x, adv_y, adv_sample_sizes

    def compute_perturbation_regions(
        self,
        input_perturbation_size: np.ndarray,
        input_perturb_sizes: List[List[int]],
        automatically_append: bool,
    ) -> Tuple[np.ndarray, List[List[int]]]:
        """
        Based on the l0 budget and the provided allowable perturbation regions we iteratively mark regions of the PE
        file for modification until we exhaust our budget.

        :param input_perturb_sizes: The size of the regions we can perturb.
        :param input_perturbation_size: The total amount of perturbation allowed on a specific sample.
        :param automatically_append: If we want to automatically append unused perturbation on the end of the malware.
        :return perturbation_size: Remaining perturbation (if any)
        :return perturb_sizes: Potentially adjusted sizes of the locations in the PE file we can perturb.
        """
        perturb_sizes = input_perturb_sizes.copy()
        perturbation_size = input_perturbation_size.copy()

        # reduce the perturbation we add at the end if we have perturbation to add in the malware itself and reduce the
        # sizes given in input_perturb_sizes if it exceeds the total input_perturbation_size
        for i, section_sizes in enumerate(perturb_sizes):
            for j, size in enumerate(section_sizes):
                if perturbation_size[i] - size >= 0:
                    logger.info("on sample %d allocate %d in perturb region", i, size)
                    perturbation_size[i] = perturbation_size[i] - size
                else:  # run out of l0 budget.
                    excess = np.abs(perturbation_size[i] - size)  # amount of overspill
                    perturbation_size[i] = perturbation_size[i] - size
                    section_sizes[j] = size - excess
                    # update the perturb sizes

                    logger.info("on sample %d ran out of l0, update to %d from %d", i, section_sizes[j], size)

                    perturb_sizes[i] = section_sizes
                    perturbation_size[i] = 0
        perturbation_size = np.where(perturbation_size < 0, 0, perturbation_size)

        if not automatically_append:
            # if we do not want to automatically append then set the remaining perturbation to 0
            perturbation_size = np.zeros_like(perturbation_size)
            # total_perturbation is now sample dependant so reassign.
            total_perturbation = np.zeros_like(perturbation_size)
            for i in range(len(perturbation_size)):
                total_perturbation[i] = np.sum(perturb_sizes[i])
            self.total_perturbation = total_perturbation
        return perturbation_size, perturb_sizes

    def pull_out_valid_samples(
        self,
        x: np.ndarray,
        y: np.ndarray,
        sample_sizes: np.ndarray,
        automatically_append: bool = True,
        perturb_sizes: Optional[List[List[int]]] = None,
        perturb_starts: Optional[List[List[int]]] = None,
    ) -> Union[
        Tuple[np.ndarray, np.ndarray, np.ndarray],
        Tuple[np.ndarray, np.ndarray, np.ndarray, List[List[int]], List[List[int]]],
    ]:
        """
        Filters the input data for samples that can be made adversarial.

        :param x: Array with input data.
        :param y: Labels to make sure the benign files are zero masked.
        :param sample_sizes: The size of the original file, before it was padded to the input size required by MalConv
        :param automatically_append: Whether to automatically append extra spare perturbation at the end of the file.
        :param perturb_sizes: List of length batch size, each element is in itself a list containing
                              the size of the allowable perturbation region
        :param perturb_starts: List of length batch size, each element is in itself a list containing
                               the start of perturbation region.

        """
        initial_dtype = x.dtype

        perturbation_size = np.zeros(len(sample_sizes), dtype=int)
        for i, sample_size in enumerate(sample_sizes):
            if self.l_0 < 1:  # l0 is a fraction of the filesize
                perturbation_size[i] = int(sample_size * self.l_0)
            else:  # or l0 is interpreted as total perturbation size
                perturbation_size[i] = int(self.l_0)

        if perturb_sizes is not None and perturb_starts is not None:
            perturbation_size, perturb_sizes = self.compute_perturbation_regions(
                perturbation_size, perturb_sizes, automatically_append
            )

        y = self.check_valid_size(y, sample_sizes, perturbation_size)
        if perturb_sizes is not None and perturb_starts is not None:
            return self.pull_out_adversarial_malware(x, y, sample_sizes, initial_dtype, perturb_sizes, perturb_starts)

        return self.pull_out_adversarial_malware(x, y, sample_sizes, initial_dtype)

    def generate(  # pylint: disable=W0221
        self,
        x: np.ndarray,
        y: Optional[np.ndarray] = None,
        sample_sizes: Optional[np.ndarray] = None,
        automatically_append: bool = True,
        verify_input_data: bool = True,
        perturb_sizes: Optional[List[List[int]]] = None,
        perturb_starts: Optional[List[List[int]]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Generates the adversarial examples. x needs to be composed of valid files by default which can support the
        adversarial perturbation and so are malicious and can support the assigned L0 budget. They can be obtained by
        using `pull_out_valid_samples` on the data.

        This check on the input data can be over-ridden by setting the flag verify_input_data to False.
        This will result in only the data which can be made adversarial being perturbed and so the resulting batch will
        be a mixture of adversarial and unperturbed data.

        To assign the L0 budget we go through each list in perturb_sizes and perturb_starts in order, and
        assign the budget based on the sizes given until the l0 budget is exhausted.

        After all the regions marked in perturb_sizes and perturb_starts have been assigned, if automatically_append
        is set to True and there is l0 budget remaining, the remaining perturbation is added at the end of the file
        in an append style attack.

        The input array is copied, so `x` itself is not modified. `self.total_perturbation` is reassigned on every
        call. Additional keyword arguments are accepted but not used by this method.

        :param x: An array with input data.
        :param y: (N, 1) binary labels to make sure the benign files are zero masked.
        :param sample_sizes: The size of the original file, before it was padded to the input size required by MalConv
        :param automatically_append: Whether to automatically append extra spare perturbation at the end of the file.
        :param verify_input_data: If to check that all the data supplied is valid for adversarial perturbations.
        :param perturb_sizes: A list of length batch size, each element is in itself a list containing
                              the size of the allowable perturbation region
        :param perturb_starts: A list of length batch size, each element is in itself a list containing
                               the start of perturbation region.
        :return x: our adversarial examples.
        :raises ValueError: If `sample_sizes` or `y` is None, or if `verify_input_data` is True and at least one
                            sample is benign or cannot hold its appended perturbation.
        :raises AssertionError: If the lengths of the inputs do not match, if only one of `perturb_sizes` /
                                `perturb_starts` is given, or if `automatically_append` is False without both of
                                them. These checks are plain `assert` statements.
        """
        import tensorflow as tf

        # make copy so original data is not modified.
        adv_x = x.copy()
        if sample_sizes is None:  # pragma: no cover
            raise ValueError("The size of the original files needs to be supplied")
        if y is None:  # pragma: no cover
            raise ValueError("Labels need to be provided so we only modify the malware")

        # check that the dimensions all match
        assert len(adv_x) == len(y)
        assert len(y) == len(sample_sizes)
        if perturb_sizes is not None:
            assert len(y) == len(perturb_sizes)
        if perturb_starts is not None:
            assert len(y) == len(perturb_starts)

        # check that if perturb_starts is provided perturb_sizes is also provided and vice versa
        if perturb_starts is not None:
            assert perturb_sizes is not None
        if perturb_sizes is not None:
            assert perturb_starts is not None

        # if we do not automatically append then make sure that we have supplied
        # start and end positions for the perturbation.
        if not automatically_append:
            assert perturb_sizes is not None
            assert perturb_starts is not None

        # per-sample l0 budget, before any of it is assigned to the supplied regions
        perturbation_size = np.zeros(len(sample_sizes), dtype=int)
        for i, sample_size in enumerate(sample_sizes):
            if self.l_0 < 1:  # l0 is a fraction of the filesize
                perturbation_size[i] = int(sample_size * self.l_0)
            else:  # or l0 is interpreted as total perturbation size
                perturbation_size[i] = int(self.l_0)
        # remember the full budget; generate_mask and get_adv_malware compare against it in their sanity checks
        self.total_perturbation = np.copy(perturbation_size)

        if perturb_sizes is not None and perturb_starts is not None:
            # from here on perturbation_size only holds what is left to append after the regions were served
            # (compute_perturbation_regions may also reassign self.total_perturbation)
            perturbation_size, perturb_sizes = self.compute_perturbation_regions(
                perturbation_size, perturb_sizes, automatically_append
            )

        # y now marks only the malware samples that can hold the perturbation left to append
        y = self.check_valid_size(y, sample_sizes, perturbation_size)

        if verify_input_data:
            if np.sum(y) != len(y):
                raise ValueError(  # pragma: no cover
                    f"{len(y) - np.sum(y)} invalid samples found in batch which cannot support the assigned "
                    f"perturbation or are benign To filter for samples that can be processed use "
                    f"pull_out_valid_samples on the samples. Checking can be disabled by using verify_input_data"
                )

        # start the optimisation from random values on all positions that may be modified
        adv_x = self.initialise_sample(
            adv_x, y, sample_sizes, perturbation_size, perturb_sizes=perturb_sizes, perturb_starts=perturb_starts
        )

        mask = self.generate_mask(
            adv_x, y, sample_sizes, perturbation_size, perturb_sizes=perturb_sizes, perturb_starts=perturb_starts
        )

        # the optimisation runs on the embeddings of the data, not on the discrete values themselves
        embeddings = tf.nn.embedding_lookup(params=self.embedding_weights, ids=adv_x.astype("int32"))

        for _ in trange(self.num_of_iterations, desc="PE Adv. Malware", disable=not self.verbose):
            gradients = self.estimator.class_gradient(embeddings, label=0)
            # go from (bsize x 1 x features x embedding size) -> (bsize x features x embedding size) in a
            # framework-agnostic manner.
            gradients = gradients[:, 0, :, :]
            # update_embeddings adds l_r * gradients * mask, so negating here steps against the
            # gradient returned for label 0
            gradients = -1 * gradients
            embeddings = self.update_embeddings(embeddings, gradients, mask)

        # map the optimised embeddings back to discrete values on the perturbed positions
        adv_x = self.get_adv_malware(
            embeddings=embeddings,
            data=adv_x,
            labels=y,
            fsize=sample_sizes,
            perturbation_size=perturbation_size,
            perturb_sizes=perturb_sizes,
            perturb_starts=perturb_starts,
        )

        return adv_x

    @staticmethod
    def process_file(
        filepath: str, padding_char: int = 256, maxlen: int = 2 ** 20
    ) -> Tuple[np.ndarray, int]:  # pragma: no cover

        """
        Go from raw file to numpy array.

        :param filepath: Path to the file we convert to a numpy array
        :param padding_char: The char to use to pad the input if it is shorter than maxlen
        :param maxlen: Maximum size of the file processed by the model. Currently set to 1MB
        :return data: A numpy array of the PE file
        :return size_of_original_file: Size of the PE file
        """
        with open(filepath, "rb") as file:
            open_file = file.read()
        size_of_original_file = len(open_file)

        data = np.ones((maxlen,), dtype=np.uint16) * padding_char
        selected_bytes = np.frombuffer(open_file[:maxlen], dtype=np.uint8)
        data[: len(selected_bytes)] = selected_bytes

        return data, size_of_original_file

    @staticmethod
    def get_peinfo(
        filepath: str, save_to_json_path: Optional[str] = None
    ) -> Tuple[List[int], List[int]]:  # pragma: no cover
        """
        Given a PE file we extract out the section information to determine the slack regions in the file.
        We return two lists 1) with the start location of the slack regions and 2) with the size of the slack region.
        We are using the lief library (https://github.com/lief-project/LIEF) to manipulate the PE file.

        The slack of a section is its raw data size minus its virtual size; only sections with positive slack are
        reported, and the slack is taken to start at the section's raw data pointer plus its virtual size.

        :param filepath: Path to file we want to analyse with lief and get the section information.
        :param save_to_json_path: (Optional) if we want to save the section information to a json file, provide the
                                  path. Per section name the file holds 'PointerToRawData', 'VirtualAddress',
                                  'VirtualSize' and 'SizeOfRawData'.
        :return start_of_slack: A list with the slack starts
        :return size_of_slack: A list with the slack sizes
        :raises ValueError: If lief fails to parse the file.
        """
        import lief  # pylint: disable=C0415

        binary = lief.parse(filepath)  # pylint: disable=I1101
        if binary is None:
            raise ValueError("Failed to load binary.")

        start_of_slack: List[int] = []
        size_of_slack: List[int] = []
        cleaned_dump: Dict[str, Dict[str, int]] = {}

        for section in binary.sections:
            raw_pointer = section.pointerto_raw_data  # type: ignore
            raw_size = section.sizeof_raw_data  # type: ignore
            virtual_size = section.virtual_size  # type: ignore

            cleaned_dump[section.name] = {
                "PointerToRawData": raw_pointer,
                # this entry used to hold the virtual size; it now holds the address its name promises and the
                # size is stored under its own key
                "VirtualAddress": section.virtual_address,  # type: ignore
                "VirtualSize": virtual_size,
                "SizeOfRawData": raw_size,
            }

            slack = raw_size - virtual_size
            if slack > 0:
                size_of_slack.append(slack)
                start_of_slack.append(raw_pointer + virtual_size)

        if save_to_json_path is not None:
            with open(save_to_json_path, "w", encoding="utf8") as outfile:
                json.dump(cleaned_dump, outfile, indent=4, sort_keys=True)

        return start_of_slack, size_of_slack

    def insert_section(
        self,
        datapoint: Union[List[int], str],
        sample_size: Optional[int] = None,
        padding_char: int = 256,
        maxlen: int = 2 ** 20,
        bytes_to_assign: Optional[int] = None,
        verbose: bool = False,
    ) -> Union[
        Tuple[np.ndarray, int, int, int, List[int], List[int]], Tuple[None, None, None, None, None, None]
    ]:  # pragma: no cover
        """
        Add a new, randomly named and randomly filled section to a PE file so that the attacker has a region to
        perturb. The lief library (https://github.com/lief-project/LIEF) is used to parse and rebuild the PE file.

        :param datapoint: either 1) path to the file to be parsed by lief,
                          or 2) list of ints that can be processed by lief.

                          If the file has already been pre-processed into a numpy array, it needs to be converted
                          into a form lief can read, for example:

                          datapoint = datapoint[0:size]  # size is the original size of the malware file
                          datapoint = datapoint.astype('uint8')
                          datapoint = datapoint.tolist()
        :param sample_size: Size of the original datapoint. Only needed when `datapoint` is not a file path and the
                            l0 budget is fractional (`self.l_0 < 1`).
        :param padding_char: Value used to pad the rebuilt file up to `maxlen`.
        :param maxlen: Length of the returned array; rebuilt files of this length or longer are rejected.
        :param bytes_to_assign: (Optional) number of random bytes to put in the new section.
                                If unspecified the whole l0 budget is used on a single section.
        :param verbose: If False (default), lief logging is disabled before parsing. This method only ever disables
                        the lief logger; it does not re-enable it when `verbose` is True.
        :return manipulated_data: Rebuilt executable (with the new section) as a uint16-based numpy array of
                                  length `maxlen`, padded with `padding_char`.
        :return len(manipulated_file): Size in bytes of the rebuilt file (i.e. after the section was inserted).
        :return information_on_section.pointerto_raw_data: The start of the inserted section.
        :return information_on_section.virtual_size: Size of the inserted section.
        :return size_of_slack: Sizes of the slack regions in the executable (including the inserted section).
        :return start_of_slack: Starts of the slack regions in the executable (including the inserted section).

        If the rebuilt file does not fit (its length is not strictly smaller than `maxlen`), a tuple of six `None`
        values is returned instead.

        :raises ValueError: If lief cannot parse `datapoint`, or if the l0 budget is fractional, `datapoint` is not
                            a path and `sample_size` was not given.
        """
        # pylint: disable=I1101
        import lief  # pylint: disable=C0415

        if not verbose:
            lief.logging.disable()

        binary = lief.PE.parse(datapoint)
        if binary is None:
            raise ValueError("Failed to load binary.")

        # Keep drawing random five letter names until one is found that no existing section uses.
        letter_low, letter_high = ord("a"), ord("z")
        while True:
            suffix = "".join(chr(random.randrange(letter_low, letter_high)) for _ in range(5))
            new_section_name = "." + suffix
            if all(existing.name != new_section_name for existing in binary.sections):
                break

        new_section = lief.PE.Section(new_section_name)

        # Work out how many random bytes the new section will hold.
        if bytes_to_assign is not None:
            number_of_bytes = bytes_to_assign
        elif self.l_0 < 1:
            # l0 is a fraction of the file size, so the size of the original sample is required
            if isinstance(datapoint, str):
                with open(datapoint, "rb") as file:
                    sample_size = len(file.read())
            elif sample_size is None:
                raise ValueError(
                    "if the data is an array and the l0 budget is fractional "
                    "the sample size must be provided"
                )
            number_of_bytes = int(sample_size * self.l_0)
        else:
            # l0 is interpreted as the total perturbation size
            number_of_bytes = int(self.l_0)

        new_section.content = [random.randint(0, 255) for _ in range(number_of_bytes)]  # type: ignore

        # The new section is placed after the furthest end point of the existing sections.
        new_section.virtual_address = max(section.virtual_address + section.size for section in binary.sections)

        candidate_section_types = [
            lief.PE.SECTION_TYPES.BSS,
            lief.PE.SECTION_TYPES.DATA,
            lief.PE.SECTION_TYPES.EXPORT,
            lief.PE.SECTION_TYPES.IDATA,
            lief.PE.SECTION_TYPES.RELOCATION,
            lief.PE.SECTION_TYPES.RESOURCE,
            lief.PE.SECTION_TYPES.TEXT,
            lief.PE.SECTION_TYPES.TLS_,
            lief.PE.SECTION_TYPES.UNKNOWN,
        ]
        binary.add_section(new_section, random.choice(candidate_section_types))

        information_on_section = binary.get_section(new_section_name)

        # Slack of a section: raw bytes on disk beyond its virtual size. Collected before the file is rebuilt.
        slack_regions = [
            (section.sizeof_raw_data - section.virtual_size, section.pointerto_raw_data + section.virtual_size)
            for section in binary.sections
        ]
        size_of_slack = [slack for slack, _ in slack_regions if slack > 0]
        start_of_slack = [start for slack, start in slack_regions if slack > 0]

        builder = lief.PE.Builder(binary)
        builder.build()

        manipulated_file = np.array(builder.get_build(), dtype=np.uint8)
        file_length = len(manipulated_file)

        manipulated_data = np.ones((maxlen,), dtype=np.uint16) * padding_char

        # Only files which are smaller than the max file size supported are processed
        if not file_length < maxlen:
            return None, None, None, None, None, None

        manipulated_data[:file_length] = manipulated_file

        return (
            manipulated_data,
            file_length,
            information_on_section.pointerto_raw_data,
            information_on_section.virtual_size,
            size_of_slack,
            start_of_slack,
        )

    @staticmethod
    def get_dos_locations(x: np.ndarray) -> Tuple[List[List[int]], List[List[int]]]:
        """
        We identify the regions in the DOS header which we can perturb adversarially.

        There are a series of "magic numbers" in this method which relate to the structure of the PE file.
        1) mz_offset = 2: the first two bytes of a PE are fixed as MZ.
        2) 0x3C: offset to the pointer to the PE header. The pointer is 4 bytes long.
        3) 0x40: end of the pointer to the PE header.

        :return batch_of_starts: A list of start locations we can perturb.
                                 This will always have the same value of 2 and 64.
        :return batch_of_sizes: Size of the perturbations we can carry out.
        :return batch_of_starts: Start locations which we can perturb.
        """
        batch_of_starts = []
        batch_of_sizes = []
        mz_offset = 2

        for i in range(len(x)):
            size = []
            start = []

            pointer_to_pe_header = x[i, int(0x3C) : int(0x40)].astype(np.uint8)

            # combine 4 unint8 bytes into unit32
            pointer_to_pe_header = (
                pointer_to_pe_header[3] << 24
                | pointer_to_pe_header[2] << 16
                | pointer_to_pe_header[1] << 8
                | pointer_to_pe_header[0]
            )

            size.append(int(0x3C) - mz_offset)
            start.append(mz_offset)

            size.append(int(pointer_to_pe_header) - int(0x40) - 1)
            start.append(int(0x40))

            batch_of_starts.append(start)
            batch_of_sizes.append(size)

        return batch_of_starts, batch_of_sizes
